package cn.nicholasld.nmqs.message;

import cn.hutool.json.JSONObject;
import cn.hutool.json.JSONUtil;
import cn.nicholasld.nmqs.config.WebSocketConfig;
import cn.nicholasld.nmqs.message.mqtt.CommonMQTT;
import cn.nicholasld.nmqs.message.mqtt.MyMQTT;
import com.alibaba.druid.support.json.JSONUtils;
import com.alibaba.druid.util.StringUtils;
import jakarta.websocket.*;
import jakarta.websocket.server.PathParam;
import jakarta.websocket.server.ServerEndpoint;
import lombok.extern.slf4j.Slf4j;
import org.eclipse.paho.client.mqttv3.MqttException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;

import javax.crypto.Cipher;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicReference;

/**
 * @author NicholasLD
 * @createTime 2023/3/29 00:33
 */
@ServerEndpoint(value = "/common/{str}", configurator = WebSocketConfig.class)
@Component
@Slf4j
public class CommonWebSocket {
    private final ExecutorService executorService = Executors.newVirtualThreadPerTaskExecutor();
    private static int onlineCount = 0;
    private static final CopyOnWriteArraySet<CommonWebSocket> webSocketSet =
            new CopyOnWriteArraySet<>();
    private Session session;
    private CommonMQTT mqtt;

    private Integer protocol;
    private String server;
    private String port;
    private String username;
    private String password;
    private String subTopic;
    private String sendTopic;

    /**
     * 连接建立成功调用的方法
     * @param session session为与某个客户端的连接会话，需要通过它来给客户端发送数据
     */
    @OnOpen
    public void onOpen(Session session, @PathParam("str") String str) {
        try {
            String decryptedData = decrypt(str, "1234567890123456");
            JSONObject entries = JSONUtil.parseObj(decryptedData);

            protocol = switch (entries.getStr("protocol")) {
                case "ws" -> 0;
                case "tcp" -> 1;
                case "wss" -> 2;
                default -> throw new IllegalStateException("Unexpected value: " + entries.getStr("protocol"));
            };
            server = entries.getStr("server");
            port = entries.getStr("port");
            username = entries.getStr("username");
            password = entries.getStr("password");
            subTopic = entries.getStr("subTopic");
            sendTopic = entries.getStr("sendTopic");

            this.session = session;
            webSocketSet.add(this);
            addOnlineCount();
            log.info("[Websocket] Websocket有新连接加入，SessionID:{}，当前在线人数为{}", session.getId(), getOnlineCount());

            try {
                sendMessage("[Websocket] 连接成功，SessionID:" + session.getId());
                connectMqttAsync();
            } catch (IOException e) {
                log.error("[Websocket] Websocket IO异常",e);
            }
        }catch (Exception e) {
            log.error("[Websocket] Websocket连接失败");
            try {
                session.close();
            } catch (IOException ioException) {
                log.error("[Websocket] Websocket IO异常",ioException);
            }
        }
    }

    public void connectMqttAsync() {
        executorService.execute(() -> {
            mqtt = new CommonMQTT(session, server, port, username, password, sendTopic, subTopic, protocol);
            mqtt.start();
        });
    }

    /**
     * 连接关闭调用的方法
     */
    @OnClose
    public void onClose() {
        //释放线程
        mqtt.release();
        executorService.shutdown();
        webSocketSet.remove(this);  //从set中删除
        subOnlineCount();           //在线数减1
        log.info("[Websocket] Websocket有一连接关闭，SessionID:{}，当前在线人数为{}", session.getId(), getOnlineCount());
    }

    /**
     * 收到客户端消息后调用的方法
     * @param message 客户端发送过来的消息
     * @param session 接收消息的session会话
     */
    @OnMessage
    public void onMessage(String message, Session session) throws MqttException {
        log.info("[Websocket] 来自客户端 SessionID:{} 的消息:{}", session.getId(), message);
        //发送到MQTT
        mqtt.publish(message);
    }

    /**
     * 发生错误时调用
     * @param session session会话
     */
    @OnError
    public void onError(Session session, Throwable error) {
        log.error("[Websocket] Websocket发生错误",error);
    }


    public void sendMessage(String message) throws IOException {
        this.session.getBasicRemote().sendText(message);
    }

    public static void sendMessage(String message, Session session) throws IOException {
        session.getBasicRemote().sendText(message);
    }


     /**
      * 群发自定义消息
      * */
    public static void sendInfo(String message) throws IOException {
        for (CommonWebSocket item : webSocketSet) {
            try {
                item.sendMessage(message);
            } catch (IOException e) {
                log.error("[Websocket] Websocket群发消息失败",e);
            }
        }
    }

    public static synchronized int getOnlineCount() {
        return onlineCount;
    }

    public static synchronized void addOnlineCount() {
        CommonWebSocket.onlineCount++;
    }

    public static synchronized void subOnlineCount() {
        CommonWebSocket.onlineCount--;
    }

    public static String decrypt(String encryptedData, String secretKey) throws Exception {
        // 将下划线替换回斜线
        encryptedData = encryptedData.replace('_', '/');

        // 解码 Base64
        byte[] encryptedBytes = Base64.getDecoder().decode(encryptedData);

        // 提取 IV 和加密数据
        byte[] iv = new byte[16];
        byte[] cipherText = new byte[encryptedBytes.length - 16];
        System.arraycopy(encryptedBytes, 0, iv, 0, 16);
        System.arraycopy(encryptedBytes, 16, cipherText, 0, cipherText.length);

        // 初始化解密器
        Cipher cipher = Cipher.getInstance("AES/CBC/PKCS5Padding");
        SecretKeySpec keySpec = new SecretKeySpec(secretKey.getBytes("UTF-8"), "AES");
        IvParameterSpec ivSpec = new IvParameterSpec(iv);
        cipher.init(Cipher.DECRYPT_MODE, keySpec, ivSpec);

        // 解密
        byte[] original = cipher.doFinal(cipherText);
        return new String(original, "UTF-8");
    }
}
